/*
 * Decompiled with CFR 0.152.
 */
package icyllis.arc3d.compiler.tree;

import icyllis.arc3d.compiler.ConstantFolder;
import icyllis.arc3d.compiler.Context;
import icyllis.arc3d.compiler.Position;
import icyllis.arc3d.compiler.analysis.Analysis;
import icyllis.arc3d.compiler.tree.ConstructorCompound;
import icyllis.arc3d.compiler.tree.ConstructorCompoundCast;
import icyllis.arc3d.compiler.tree.ConstructorScalarCast;
import icyllis.arc3d.compiler.tree.ConstructorVectorSplat;
import icyllis.arc3d.compiler.tree.Expression;
import icyllis.arc3d.compiler.tree.Literal;
import icyllis.arc3d.compiler.tree.Node;
import icyllis.arc3d.compiler.tree.TreeVisitor;
import icyllis.arc3d.compiler.tree.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Objects;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

public final class Swizzle
extends Expression {
    public static final byte X = 0;
    public static final byte Y = 1;
    public static final byte Z = 2;
    public static final byte W = 3;
    public static final byte R = 4;
    public static final byte G = 5;
    public static final byte B = 6;
    public static final byte A = 7;
    public static final byte S = 8;
    public static final byte T = 9;
    public static final byte P = 10;
    public static final byte Q = 11;
    public static final byte ZERO = 16;
    public static final byte ONE = 17;
    private final Expression mBase;
    private final byte[] mComponents;

    private Swizzle(int position, Type type, Expression base, byte[] components) {
        super(position, type);
        assert (components.length >= 1 && components.length <= 4);
        this.mBase = base;
        this.mComponents = components;
    }

    private static boolean validateNameSet(byte[] components) {
        int set = -1;
        block6: for (byte component : components) {
            int newSet;
            switch (component) {
                case 0: 
                case 1: 
                case 2: 
                case 3: {
                    newSet = 0;
                    break;
                }
                case 4: 
                case 5: 
                case 6: 
                case 7: {
                    newSet = 1;
                    break;
                }
                case 8: 
                case 9: 
                case 10: 
                case 11: {
                    newSet = 2;
                    break;
                }
                case 16: 
                case 17: {
                    continue block6;
                }
                default: {
                    return false;
                }
            }
            if (set == -1) {
                set = newSet;
                continue;
            }
            if (set == newSet) continue;
            return false;
        }
        return true;
    }

    /*
     * WARNING - void declaration
     */
    @Nullable
    private static Expression optimizeSwizzle(@Nonnull Context context, int pos, @Nonnull ConstructorCompound base, byte[] components, int numComponents) {
        void var15_25;
        Expression[] baseArguments = base.getArguments();
        Type exprType = base.getType();
        Type componentType = exprType.getComponentType();
        int numConstructorArgs = exprType.getRows();
        short[] argMap = new short[4];
        int writeIdx = 0;
        for (int argIdx = 0; argIdx < baseArguments.length; ++argIdx) {
            Expression arg = baseArguments[argIdx];
            Type argType = arg.getType();
            if (!argType.isScalar() && !argType.isVector()) {
                return null;
            }
            int argComps = argType.getComponents();
            for (int i = 0; i < argComps; ++i) {
                argMap[writeIdx] = (short)(argIdx | i << 8);
                ++writeIdx;
            }
        }
        assert (writeIdx == numConstructorArgs);
        byte[] exprUsed = new byte[4];
        for (Object object : (Expression)components) {
            byte by = (byte)argMap[object];
            exprUsed[by] = (byte)(exprUsed[by] + 1);
        }
        for (int index = 0; index < numConstructorArgs; ++index) {
            byte constructorArgIndex = (byte)argMap[index];
            Expression baseArg = baseArguments[constructorArgIndex];
            if (exprUsed[constructorArgIndex] > 1 && !Analysis.isTrivialExpression(baseArg)) {
                return null;
            }
            if (exprUsed[constructorArgIndex] == 1 || !Analysis.hasSideEffects(baseArg)) continue;
            return null;
        }
        class ReorderedArgument {
            final byte mArgIndex;
            final byte[] mComponents = new byte[4];
            byte mNumComponents = 0;

            ReorderedArgument(byte argIndex) {
                this.mArgIndex = argIndex;
            }
        }
        ArrayList<ReorderedArgument> reorderedArgs = new ArrayList<ReorderedArgument>(4);
        byte[] constructorArgIndex = components;
        int baseArg = constructorArgIndex.length;
        boolean bl = false;
        while (var15_25 < baseArg) {
            byte c2 = constructorArgIndex[var15_25];
            short argument = argMap[c2];
            byte argumentIndex = (byte)argument;
            byte argumentComponent = (byte)(argument >> 8);
            Expression baseArg2 = baseArguments[argumentIndex];
            if (baseArg2.getType().isScalar()) {
                assert (argumentComponent == 0);
                reorderedArgs.add(new ReorderedArgument(argumentIndex));
            } else {
                assert (baseArg2.getType().isVector());
                assert (argumentComponent < baseArg2.getType().getRows());
                if (reorderedArgs.isEmpty() || ((ReorderedArgument)reorderedArgs.get((int)(reorderedArgs.size() - 1))).mArgIndex != argumentIndex) {
                    ReorderedArgument toAdd = new ReorderedArgument(argumentIndex);
                    byte by = toAdd.mNumComponents;
                    toAdd.mNumComponents = (byte)(by + 1);
                    toAdd.mComponents[by] = argumentComponent;
                    reorderedArgs.add(toAdd);
                } else {
                    ReorderedArgument last = (ReorderedArgument)reorderedArgs.get(reorderedArgs.size() - 1);
                    assert (last.mNumComponents != 0);
                    byte by = last.mNumComponents;
                    last.mNumComponents = (byte)(by + 1);
                    last.mComponents[by] = argumentComponent;
                }
            }
            ++var15_25;
        }
        Expression[] newArgs = new Expression[numComponents];
        for (int i = 0; i < reorderedArgs.size(); ++i) {
            ReorderedArgument reorderedArgument = (ReorderedArgument)reorderedArgs.get(i);
            Expression newArg = baseArguments[reorderedArgument.mArgIndex].clone();
            newArgs[i] = reorderedArgument.mNumComponents == 0 ? newArg : Swizzle.make(context, pos, newArg, reorderedArgument.mComponents, reorderedArgument.mNumComponents);
        }
        return ConstructorCompound.make(context, pos, componentType.toVector(context, numComponents), newArgs);
    }

    @Nullable
    public static Expression convert(@Nonnull Context context, int position, @Nonnull Expression base, int maskPosition, @Nonnull String maskString) {
        if (maskString.length() > 4) {
            context.error(maskPosition, "too many components in swizzle mask");
            return null;
        }
        byte[] inComponents = new byte[maskString.length()];
        for (int i = 0; i < maskString.length(); ++i) {
            int c;
            char field = maskString.charAt(i);
            switch (field) {
                case 'x': {
                    c = 0;
                    break;
                }
                case 'r': {
                    c = 4;
                    break;
                }
                case 's': {
                    c = 8;
                    break;
                }
                case 'y': {
                    c = 1;
                    break;
                }
                case 'g': {
                    c = 5;
                    break;
                }
                case 't': {
                    c = 9;
                    break;
                }
                case 'z': {
                    c = 2;
                    break;
                }
                case 'b': {
                    c = 6;
                    break;
                }
                case 'p': {
                    c = 10;
                    break;
                }
                case 'w': {
                    c = 3;
                    break;
                }
                case 'a': {
                    c = 7;
                    break;
                }
                case 'q': {
                    c = 11;
                    break;
                }
                case '0': {
                    c = 16;
                    break;
                }
                case '1': {
                    c = 17;
                    break;
                }
                default: {
                    int offset = Position.getStartOffset(maskPosition) + i;
                    context.error(Position.range(offset, offset + 1), String.format("invalid swizzle component '%c'", Character.valueOf(field)));
                    return null;
                }
            }
            inComponents[i] = c;
        }
        if (!Swizzle.validateNameSet(inComponents)) {
            context.error(maskPosition, "swizzle components '" + maskString + "' do not come from the same name set");
            return null;
        }
        Type baseType = base.getType();
        if (!baseType.isVector() && !baseType.isScalar()) {
            context.error(position, "cannot swizzle value of type '" + baseType + "'");
            return null;
        }
        byte[] maskComponents = new byte[inComponents.length];
        int numComponents = 0;
        boolean foundXYZW = false;
        block28: for (int i = 0; i < inComponents.length; ++i) {
            byte c = inComponents[i];
            switch (c) {
                case 16: 
                case 17: {
                    continue block28;
                }
                case 0: 
                case 4: 
                case 8: {
                    foundXYZW = true;
                    int n = numComponents;
                    numComponents = (byte)(numComponents + 1);
                    maskComponents[n] = 0;
                    continue block28;
                }
                case 1: 
                case 5: 
                case 9: {
                    foundXYZW = true;
                    if (baseType.getRows() >= 2) {
                        int n = numComponents;
                        numComponents = (byte)(numComponents + 1);
                        maskComponents[n] = 1;
                        continue block28;
                    }
                }
                case 2: 
                case 6: 
                case 10: {
                    foundXYZW = true;
                    if (baseType.getRows() >= 3) {
                        int n = numComponents;
                        numComponents = (byte)(numComponents + 1);
                        maskComponents[n] = 2;
                        continue block28;
                    }
                }
                case 3: 
                case 7: 
                case 11: {
                    foundXYZW = true;
                    if (baseType.getRows() >= 4) {
                        int n = numComponents;
                        numComponents = (byte)(numComponents + 1);
                        maskComponents[n] = 3;
                        continue block28;
                    }
                }
                default: {
                    int offset = Position.getStartOffset(maskPosition) + i;
                    context.error(Position.range(offset, offset + 1), String.format("swizzle component '%c' is out of range for type '%s'", Character.valueOf(maskString.charAt(i)), baseType));
                    return null;
                }
            }
        }
        if (!foundXYZW) {
            context.error(maskPosition, "swizzle must refer to base expression");
            return null;
        }
        if ((base = baseType.coerceExpression(context, base)) == null) {
            return null;
        }
        Expression expr = Swizzle.make(context, position, base, maskComponents, numComponents);
        if (numComponents == inComponents.length) {
            return expr;
        }
        ArrayList<Expression> constructorArgs = new ArrayList<Expression>(3);
        constructorArgs.add(expr);
        Type scalarType = baseType.getComponentType();
        int maskFieldIdx = 0;
        int constantFieldIdx = numComponents;
        int constantZeroIdx = -1;
        int constantOneIdx = -1;
        numComponents = 0;
        block29: for (byte component : inComponents) {
            switch (component) {
                case 16: {
                    if (constantZeroIdx == -1) {
                        constructorArgs.add(Literal.make(position, 0.0, scalarType));
                        int n = constantFieldIdx;
                        constantFieldIdx = (byte)(constantFieldIdx + 1);
                        constantZeroIdx = n;
                    }
                    int n = numComponents;
                    numComponents = (byte)(numComponents + 1);
                    maskComponents[n] = constantZeroIdx;
                    continue block29;
                }
                case 17: {
                    if (constantOneIdx == -1) {
                        constructorArgs.add(Literal.make(position, 1.0, scalarType));
                        int n = constantFieldIdx;
                        constantFieldIdx = (byte)(constantFieldIdx + 1);
                        constantOneIdx = n;
                    }
                    int n = numComponents;
                    numComponents = (byte)(numComponents + 1);
                    maskComponents[n] = constantOneIdx;
                    continue block29;
                }
                default: {
                    int n = numComponents;
                    numComponents = (byte)(numComponents + 1);
                    int n2 = maskFieldIdx;
                    maskFieldIdx = (byte)(maskFieldIdx + 1);
                    maskComponents[n] = n2;
                }
            }
        }
        expr = ConstructorCompound.make(context, position, scalarType.toVector(context, constantFieldIdx), constructorArgs.toArray(new Expression[0]));
        return Swizzle.make(context, position, expr, maskComponents, numComponents);
    }

    @Nonnull
    public static Expression make(@Nonnull Context context, int position, @Nonnull Expression base, byte[] components, int numComponents) {
        ConstructorCompound ctor;
        Expression replacement;
        Type baseType = base.getType();
        assert (baseType.isVector() || baseType.isScalar());
        for (int i = 0; i < numComponents; ++i) {
            byte component = components[i];
            assert (component == 0 || component == 1 || component == 2 || component == 3);
        }
        assert (numComponents >= 1 && numComponents <= 4);
        if (baseType.isScalar()) {
            return ConstructorVectorSplat.make(position, baseType.toVector(context, numComponents), base);
        }
        if (numComponents == baseType.getRows()) {
            boolean identity = true;
            for (int i = 0; i < numComponents; ++i) {
                if (components[i] == i) continue;
                identity = false;
                break;
            }
            if (identity) {
                base.mPosition = position;
                return base;
            }
        }
        if (base instanceof Swizzle) {
            Swizzle b = (Swizzle)base;
            byte[] combined = new byte[numComponents];
            for (int i = 0; i < numComponents; ++i) {
                byte c = components[i];
                combined[i] = b.getComponents()[c];
            }
            return Swizzle.make(context, position, b.getBase(), combined, numComponents);
        }
        Expression value = ConstantFolder.getConstantValueForVariable(base);
        if (value instanceof ConstructorVectorSplat) {
            ConstructorVectorSplat ctor2 = (ConstructorVectorSplat)value;
            Type ctorType = ctor2.getComponentType().toVector(context, numComponents);
            return ConstructorVectorSplat.make(position, ctorType, ctor2.getArgument().clone());
        }
        if (value instanceof ConstructorCompoundCast) {
            ConstructorCompoundCast ctor3 = (ConstructorCompoundCast)value;
            Type ctorType = ctor3.getComponentType().toVector(context, numComponents);
            Expression swizzled = Swizzle.make(context, position, ctor3.getArgument().clone(), components, numComponents);
            Objects.requireNonNull(swizzled);
            return ctorType.getRows() > 1 ? ConstructorCompoundCast.make(position, ctorType, swizzled) : ConstructorScalarCast.make(context, position, ctorType, swizzled);
        }
        if (value.getKind() == Node.ExpressionKind.CONSTRUCTOR_COMPOUND && (replacement = Swizzle.optimizeSwizzle(context, position, ctor = (ConstructorCompound)value, components, numComponents)) != null) {
            return replacement;
        }
        return new Swizzle(position, baseType.getComponentType().toVector(context, numComponents), base, Arrays.copyOf(components, numComponents));
    }

    @Override
    public Node.ExpressionKind getKind() {
        return Node.ExpressionKind.SWIZZLE;
    }

    @Override
    public boolean accept(@Nonnull TreeVisitor visitor) {
        if (visitor.visitSwizzle(this)) {
            return true;
        }
        return this.mBase != null && this.mBase.accept(visitor);
    }

    public Expression getBase() {
        return this.mBase;
    }

    public byte[] getComponents() {
        return this.mComponents;
    }

    @Override
    @Nonnull
    public Expression clone(int position) {
        return new Swizzle(position, this.getType(), this.mBase.clone(), this.mComponents);
    }

    @Override
    @Nonnull
    public String toString(int parentPrecedence) {
        StringBuilder result = new StringBuilder(this.mBase.toString(2));
        result.append('.');
        for (byte component : this.mComponents) {
            result.append(switch (component) {
                case 0 -> 'x';
                case 1 -> 'y';
                case 2 -> 'z';
                case 3 -> 'w';
                default -> throw new IllegalStateException();
            });
        }
        return result.toString();
    }
}

